#pragma once
#include <math.h>
#include <memory>

#if !defined(PI_F)
// Single-precision approximation of pi.
#define PI_F (3.1415927f)
#endif

#if !defined(DEG2RAD)
// Degrees -> radians, evaluated in float arithmetic. The argument is parenthesized.
#define DEG2RAD(deg) (PI_F * (deg) / 180.0f)
#endif

#if !defined(RAD2DEG)
// Radians -> degrees, evaluated in float arithmetic. The argument is parenthesized.
#define RAD2DEG(rad) (180.0f * (rad) / PI_F)
#endif

namespace DXBase
{
	// --- WinRT value structs (plain data, no behavior) ---

	// 2-component float vector. Member order matters: at(float2&, i) indexes the members
	// as a contiguous float array.
	public value struct float2
	{
		float x;
		float y;
	};
	// 3-component float vector (same layout rule as float2).
	public value struct float3
	{
		float x;
		float y;
		float z;
	};
	// 4-component float vector (same layout rule as float2).
	public value struct float4
	{
		float x;
		float y;
		float z;
		float w;
	};
	// Quaternion components x, y, z, w (operations live in QuaternionOp).
	public value struct quaternion
	{
		float x;
		float y;
		float z;
		float w;
	};

	// 4x4 float matrix. Members are named m<row><col> (zero-based); at(float4x4&, row, col)
	// reads them as a row-major array of 16 floats, so the declaration order must not change.
	public value struct float4x4
	{
		float m00, m01, m02, m03;
		float m10, m11, m12, m13;
		float m20, m21, m22, m23;
		float m30, m31, m32, m33;
	};
	// 3x2 affine 2D matrix with one-based member names m11..m32, in the same order that
	// makeD2D(const float3x2&) copies into a D2D1_MATRIX_3X2_F.
	public value struct float3x2
	{
		float m11;
		float m12;
		float m21;
		float m22;
		float m31;
		float m32;
	};

	// Unsigned 2D size (units not fixed here; presumably pixels — confirm with callers).
	public value struct USize
	{
		uint32 width;
		uint32 height;
	};
	// Start/length pair; converted to and from DWRITE_TEXT_RANGE by makeDWrite/fromDWrite.
	public value struct Range
	{
		uint32 Start;
		uint32 Length;
	};

	// Box given by an origin (x, y, z) and an extent (width, height, depth).
	// makeD3D(const box3d&) maps it to a D3D11 viewport with depth range [z, z + depth].
	public value struct box3d
	{
		float x, y, z;
		float width, height, depth;
	};

	// A color packed into 32 bits, stored in byte order B, G, R, A so that it matches
	// DXGI_FORMAT_B8G8R8A8_UNORM. The member order is the memory layout — do not reorder.
	public value struct color32
	{
		byte b;
		byte g;
		byte r;
		byte a;
	};


	// --- free functions operating on the types above

	// Exact equality of two packed colors, compared channel by channel.
	// Comparing the four bytes individually avoids reinterpreting the struct as a uint32,
	// which violated strict aliasing and assumed 4-byte alignment of a byte-aligned struct.
	inline bool operator ==(const color32& s1, const color32& s2)
	{
		return s1.b == s2.b && s1.g == s2.g && s1.r == s2.r && s1.a == s2.a;
	}


	// Two sizes are equal when both dimensions match.
	inline bool operator ==(const USize& s1, const USize& s2)
	{
		return s1.width == s2.width && s1.height == s2.height;
	}
	// Inequality as the negation of operator== (equivalent to "some dimension differs").
	inline bool operator !=(const USize& s1, const USize& s2)
	{
		return !(s1 == s2);
	}

	// Exact (floating-point ==) equality of two boxes: origin (x, y, z) and extent
	// (width, height, depth) must all match.
	inline bool operator==(const box3d& a, const box3d& b)
	{
		return
			a.x == b.x && a.width == b.width
			&& a.y == b.y && a.height == b.height
			// z must be compared for equality, not tested for being non-zero.
			&& a.z == b.z && a.depth == b.depth
			;
	}

	// Factory functions for the vector structs, presumably filling the members from the
	// arguments in order — confirm against the definitions. They are only declared here
	// (defined out of line), so these signatures must keep matching those definitions.
	float2 vector2(float x, float y);
	float3 vector3(float x, float y, float z);
	// w defaults to 1 when omitted.
	float4 vector4(float x, float y, float z, float w = 1);
	// NOTE(review): presumably (v.x, v.y, v.z, w) — confirm against the definition.
	float4 vector4(float3& v, float w);

	// Mutable component access by index (0 -> x, 1 -> y, 2 -> z, 3 -> w). The members are
	// treated as a contiguous float array; the index is NOT range-checked.
	inline float& at(float2& v, unsigned int index)
	{
		float* components = reinterpret_cast<float*>(&v);
		return components[index];
	}
	inline float& at(float3& v, unsigned int index)
	{
		float* components = reinterpret_cast<float*>(&v);
		return components[index];
	}
	inline float& at(float4& v, unsigned int index)
	{
		float* components = reinterpret_cast<float*>(&v);
		return components[index];
	}

	// Dot (inner) products: the sum of the component-wise products.
	inline float dot(float2& a, float2& b)  { return a.x * b.x + a.y * b.y; }
	inline float dot(float3& a, float3& b)  { return a.x * b.x + a.y * b.y + a.z * b.z; }
	// The w term is a product like the others (it used to be the sum a.w + b.w).
	inline float dot(float4& a, float4& b)  { return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; }

	// Euclidean length: square root of the sum of squared components (w included for float4).
	inline float length(float2& a)
	{
		return sqrt(a.x * a.x + a.y * a.y);
	}
	inline float length(float3& a)
	{
		return sqrt(a.x * a.x + a.y * a.y + a.z * a.z);
	}
	inline float length(float4& a)
	{
		return sqrt(a.x * a.x + a.y * a.y + a.z * a.z + a.w * a.w);
	}

	// Cross product a x b.
	inline float3 cross(float3& a, float3& b)
	{
		const float cx = (a.y * b.z) - (a.z * b.y);
		const float cy = (a.z * b.x) - (a.x * b.z);
		const float cz = (a.x * b.y) - (a.y * b.x);
		return vector3(cx, cy, cz);
	}

	// --- float2 arithmetic: component-wise +, -, *, negation, and scalar scaling.
	inline float2 operator+(float2& a, float2& b)
	{
		return vector2(a.x + b.x, a.y + b.y);
	}
	inline float2 operator-(float2& a, float2& b)
	{
		return vector2(a.x - b.x, a.y - b.y);
	}
	// Negation.
	inline float2 operator-(float2& a)
	{
		return vector2(-a.x, -a.y);
	}
	// Component-wise product.
	inline float2 operator*(float2& a, float2& b)
	{
		return vector2(a.x * b.x, a.y * b.y);
	}
	inline float2 operator*(float2& a, float s)
	{
		return vector2(a.x * s, a.y * s);
	}
	// Scalar on the left: same result as a * s.
	inline float2 operator*(float s, float2& a)
	{
		return vector2(a.x * s, a.y * s);
	}
	// Each component divided by s (no check for s == 0).
	inline float2 operator/(float2& a, float s)
	{
		return vector2(a.x / s, a.y / s);
	}
		 
	// --- float3 arithmetic: component-wise +, -, *, negation, and scalar scaling.
	inline float3 operator+(float3& a, float3& b)
	{
		return vector3(a.x + b.x, a.y + b.y, a.z + b.z);
	}
	inline float3 operator-(float3& a, float3& b)
	{
		return vector3(a.x - b.x, a.y - b.y, a.z - b.z);
	}
	// Negation.
	inline float3 operator-(float3& a)
	{
		return vector3(-a.x, -a.y, -a.z);
	}
	// Component-wise product (not the cross product; see cross()).
	inline float3 operator*(float3& a, float3& b)
	{
		return vector3(a.x * b.x, a.y * b.y, a.z * b.z);
	}
	inline float3 operator*(float3& a, float s)
	{
		return vector3(a.x * s, a.y * s, a.z * s);
	}
	// Scalar on the left: same result as a * s.
	inline float3 operator*(float s, float3& a)
	{
		return vector3(a.x * s, a.y * s, a.z * s);
	}
	// Each component divided by s (no check for s == 0).
	inline float3 operator/(float3& a, float s)
	{
		return vector3(a.x / s, a.y / s, a.z / s);
	}
		 
	// --- float4 arithmetic: component-wise +, -, *, negation, and scalar scaling.
	// All four components (including w) take part in every operation.
	inline float4 operator+(float4& a, float4& b) { return vector4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }
	inline float4 operator-(float4& a, float4& b) { return vector4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); }
	inline float4 operator-(float4& a) { return vector4(-a.x, -a.y, -a.z, -a.w); }
	// Component-wise product: each output component multiplies the matching components
	// of a and b (the z term previously read a.w * b.z).
	inline float4 operator*(float4& a, float4& b) { return vector4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); }
	inline float4 operator*(float4& a, float s) { return vector4(a.x * s, a.y * s, a.z * s, a.w * s); }
	inline float4 operator*(float s, float4& a) { return a * s; }
	// Each component divided by s (no check for s == 0).
	inline float4 operator/(float4& a, float s) { return vector4(a.x / s, a.y / s, a.z / s, a.w / s); }

	// Returns a scaled to unit length. There is no guard: a zero-length input divides by zero.
	inline float2 normalize(float2& a)
	{
		const float len = length(a);
		return a / len;
	}
	inline float3 normalize(float3& a)
	{
		const float len = length(a);
		return a / len;
	}
	inline float4 normalize(float4& a)
	{
		const float len = length(a);
		return a / len;
	}

	// Exact component-wise equality (plain float ==, no tolerance).
	inline bool operator==(const float2& a, const float2& b)
	{
		return a.x == b.x
			&& a.y == b.y;
	}
	inline bool operator==(const float3& a, const float3& b)
	{
		return a.x == b.x
			&& a.y == b.y
			&& a.z == b.z;
	}
	inline bool operator==(const float4& a, const float4& b)
	{
		return a.x == b.x
			&& a.y == b.y
			&& a.z == b.z
			&& a.w == b.w;
	}
	// Rotation equality: q and -q encode the same rotation, so two quaternions compare equal
	// when their components match exactly, or match with every sign flipped.
	// (Choosing the sign from a.w * b.w wrongly treated q != q whenever that product was 0,
	// e.g. for w == 0 or on underflow.)
	inline bool operator==(const quaternion& a, const quaternion& b)
	{
		const bool same = a.x == b.x && a.y == b.y && a.z == b.z && a.w == b.w;
		const bool negated = a.x == -b.x && a.y == -b.y && a.z == -b.z && a.w == -b.w;
		return same || negated;
	}

	// Interpolation between a and b by t, declared here and defined out of line.
	// NOTE(review): presumably linear, a + (b - a) * t (t = 0 -> a, t = 1 -> b) — confirm.
	float2 lerp(const float2& a, const float2& b, float t);
	float3 lerp(const float3& a, const float3& b, float t);
	float4 lerp(const float4& a, const float4& b, float t);

	// --- float4x4 matrix operations (plain functions, not templates)
	// Mutable element (row, col) of a 4x4 matrix: the 16 members are read as a row-major
	// float array, so index = row * 4 + col (member m<row><col>). Not range-checked.
	inline float& at(float4x4& a, unsigned int row, unsigned col)
	{
		float* cells = reinterpret_cast<float*>(&a);
		return cells[row * 4 + col];
	}

	// float4x4 construction and transforms. All of these are declared here and defined out of
	// line, so these signatures must keep matching the definitions.

	// NOTE(review): presumably fills every element with `value` (default 0) — confirm.
	float4x4 matrix4x4(float value = 0);
	// Matrix from three basis vectors, optionally plus a translation `origin` (presumably;
	// whether the vectors become rows or columns is not visible here — confirm).
	float4x4 matrix4x4(float3& vx, float3& vy, float3& vz);
	float4x4 matrix4x4(float3& vx, float3& vy, float3& vz, float3& origin);
	// Matrix from 16 explicit elements named i<row><col> (one-based). NOTE(review): presumably
	// row-major like at(float4x4&, row, col) — confirm.
	float4x4 matrix4x4(
		float i11, float i12, float i13, float i14,
		float i21, float i22, float i23, float i24,
		float i31, float i32, float i33, float i34,
		float i41, float i42, float i43, float i44
		);

	float4x4 transpose(float4x4& m);
	// Matrix product m1 * m2 (operator*(float4x4&, float4x4&) forwards here).
	float4x4 mul(float4x4& m1,float4x4& m2);
	float4x4 identity();
	// Translation and scale matrices, from separate components or from a float3.
	float4x4 translation(float x, float y, float z);
	float4x4 translation(float3& v);
	float4x4 scale(float x, float y, float z);
	float4x4 scale(float3& v);
	// Rotation about a single axis; the parameter names suggest the angle is in degrees — confirm.
	float4x4 rotationX(float degreeX);
	float4x4 rotationY(float degreeY);
	float4x4 rotationZ(float degreeZ);
	// Rotation from a float4 (presumably axis in xyz, angle in w — confirm), or the rotation
	// that turns direction v0 into direction v1 (presumably — confirm).
	float4x4 rotation(float4& rotation);
	float4x4 rotation(float3& v0, float3& v1);
	// Given v0 and v1, returns the rotation between them as an axis and an angle in degrees
	// packed into a float4 (presumably axis in xyz, angle in w — confirm).
	float4 getrotation(float3& v0, float3& v1);

	// Matrix product; forwards to mul().
	inline float4x4 operator*(float4x4& a,float4x4& b)
	{
		float4x4 product = mul(a, b);
		return product;
	}
	// Matrix * vector product (defined out of line).
	float4 operator*(float4x4& a, float4& b);

	// Determinant and inverse of a 4x4 matrix (defined out of line).
	// NOTE(review): the behavior of invert() on a singular matrix is not visible here — confirm.
	float determinant(const float4x4& m);
	float4x4 invert(const float4x4& m);

	// Matrix equality (defined out of line; presumably exact element-wise comparison — confirm).
	bool operator==(const float4x4& a, const float4x4& b);


	// Split m into translation, rotation and scale, with the rotation returned either as a
	// float4 or as a quaternion. NOTE(review): the bool result presumably reports whether
	// decomposition succeeded — confirm against the definitions.
	bool decompose(float4x4& m, float3& translation, float4& rot, float3& scale);
	bool decompose(float4x4& m, float3& translation, quaternion& q, float3& scale);
	// Quaternion for the rotation from v0 to v1 (presumably — confirm).
	quaternion getquaternion(float3& v0, float3& v1);

	// Builds a quaternion from its four components, in member order.
	inline quaternion cquaternion(float x, float y, float z, float w)
	{
		quaternion result = { x, y, z, w };
		return result;
	}
	// Builds a quaternion whose x, y, z come from v and whose w is the given value.
	inline quaternion cquaternion(float3 v, float w)
	{
		return cquaternion(v.x, v.y, v.z, w);
	}

	// Static quaternion helpers; all members are declared here and defined out of line.
	// Background reading:
	//http://www.idevgames.com/articles/quaternions
	//http://willperone.net/Code/quaternion.php
	class QuaternionOp
	{
	public:
		// Rotation matrix for q, and the quaternion recovered from a matrix
		// (presumably its rotation part — confirm against the definitions).
		static float4x4 Rotation(quaternion q);
		static quaternion Extract(float4x4 m);

		// NOTE(review): presumably scales q to unit length — confirm.
		static quaternion Normalize(quaternion q);

		// Conversion to and from the float4 rotation form (presumably axis in xyz, angle in w;
		// angle units not visible here — confirm).
		static quaternion FromRotation(float4 rot);
		static float4 ToRotation(quaternion q);

		// Inverse, quaternion * quaternion, and quaternion applied to a point
		// (presumably rotating p by q — confirm).
		static quaternion Invert(quaternion q);
		static quaternion Mul(quaternion q1, quaternion q2);
		static float3 Mul(quaternion q, float3 p);

		// Interpolation from q1 to q2 by alpha (presumably spherical linear interpolation,
		// alpha in [0, 1] — confirm).
		static quaternion Slerp(quaternion q1, quaternion q2, float alpha);
	};


	// Dense size x size matrix of T, stored row-major in a single heap array. It is used for
	// cofactor (Laplace) expansion determinants and inverses.
	// Owns its storage: copies are deep, and moves transfer the buffer.
	// NOTE: Determinant/Adjoint/Invert recurse through Minor() for every element, which is
	// factorial-time in `size`. That is fine only for small matrices.
	template <class T>
	struct SquareMatrix
	{
		int size; // number of rows (== number of columns)
		T *data;  // size*size elements, row-major; nullptr only after being moved from

		// Creates an n x n matrix whose elements are value-initialized (0 for arithmetic T).
		SquareMatrix(int n) : size(n), data(new T[n * n]())
		{
		}
		// Deep copy; the source is left untouched. The old "copy" constructor stole the
		// source's buffer, so e.g. Minor() on a 1x1 matrix emptied the matrix it was called on.
		SquareMatrix(const SquareMatrix<T>& other)
			: size(other.size), data(other.data ? new T[other.size * other.size] : nullptr)
		{
			if (data)
				for (int k = 0; k < size * size; k++)
					data[k] = other.data[k];
		}
		// Move: takes over the buffer and leaves `other` empty (size 0, data nullptr).
		SquareMatrix(SquareMatrix<T>&& other) : size(other.size), data(other.data)
		{
			other.size = 0;
			other.data = nullptr;
		}
		// Copy-and-swap assignment (serves as both copy and move assignment). Without it, the
		// implicit assignment copied the pointer and both objects deleted the same buffer.
		SquareMatrix<T>& operator=(SquareMatrix<T> other)
		{
			int otherSize = other.size;
			other.size = size;
			size = otherSize;
			T* otherData = other.data;
			other.data = data;
			data = otherData;
			return *this;
		}
		~SquareMatrix()
		{
			delete[] data; // delete[] on nullptr is a no-op
			data = nullptr;
		}
		// Element (i, j); not range-checked. Returns a mutable reference even on a const
		// matrix, because constness does not propagate through the data pointer.
		T& at(int i, int j) const { return data[i * size + j]; }

		// Transposes the matrix in place.
		void Transpose()
		{
			for(int i=0; i<size-1; i++)
				for(int j=i+1; j<size; j++)
				{
					auto tmp = at(i, j);
					at(i,j)  = at(j,i);
					at(j,i)  = tmp;
				}
		}

		// Returns the (size-1) x (size-1) matrix obtained by deleting row i and column j.
		// A 1x1 matrix has no proper minor, so a copy of the matrix itself is returned. This
		// is what makes Invert() yield [1/a] for a 1x1 matrix [a].
		SquareMatrix<T> Minor(int i, int j) const
		{
			if (size == 1)
				return *this;

			SquareMatrix<T> result(size - 1);
			for (int iSrc = 0; iSrc < size; iSrc++)
				for (int jSrc = 0; jSrc < size; jSrc++)
				{
					if (iSrc == i) continue;
					if (jSrc == j) continue;
					int iDst = iSrc<i ? iSrc : iSrc-1;
					int jDst = jSrc<j ? jSrc : jSrc-1;
					result.at(iDst,jDst) = at(iSrc,jSrc);
				}
			return result;
		}
		// Determinant by cofactor expansion along row 0. Accumulates in T (it used float).
		T Determinant() const
		{
			if(size == 1)
				return data[0];
			T result = 0;
			for (int i=0; i<size; i++)
			{
				auto minor = Minor(0, i);
				auto det = minor.Determinant();
				result += (i%2==0 ? 1 : -1) * at(0,i) * det;
			}
			return result;
		}
		// Returns the matrix of cofactors: result(i, j) = (-1)^(i+j) * det(Minor(i, j)).
		// NOTE: this is NOT transposed, so it is the cofactor matrix rather than the classical
		// adjugate (adjoint). Confirm that callers expect this.
		SquareMatrix<T> Adjoint() const
		{
			SquareMatrix<T> result(size);
			for (int i = 0; i < size; i++)
				for (int j = 0; j < size; j++)
				{
					auto ij = i+j;
					result.at(i,j) = (ij % 2 == 0 ? 1 : -1) * Minor(i,j).Determinant();
				}
			return result;
		}
		// Inverse via transposed cofactors divided by the determinant. The determinant is
		// computed from the row-0 cofactors, which are reused as column 0 of the result.
		// NOTE: if the determinant is exactly 0 the returned matrix is NOT an inverse. Its first
		// column holds the unscaled row-0 cofactors and the rest is zero. Check Determinant()
		// first when singular input is possible.
		SquareMatrix<T> Invert() const
		{
			SquareMatrix<T> result(size);
			T det = 0;
			for (int i=0; i<size; i++)
			{
				auto minor = Minor(0, i);
				auto num = (i % 2 == 0 ? 1 : -1) * minor.Determinant();
				det += at(0,i) * num;
				result.at(i,0) = num;
			}
			if (!det)
				return result;

			for (int i=0; i<size; i++)
				result.at(i,0) = result.at(i,0) / det;
			for (int i = 1; i < size; i++)
				for (int j = 0; j < size; j++)
				{
					auto ij = i+j;
					result.at(j,i) = (ij % 2 == 0 ? 1 : -1) * Minor(i,j).Determinant() / det;
				}
			return result;
		}
	};


	// --- 2D affine transforms on float3x2. Declarations without a body are defined out of
	// line. Angle parameters are named "degrees" (units presumably degrees — confirm).
	float3x2 matrix2d(float m11, float m12, float m21, float m22, float m31, float m32);
	float3x2 identity2d();
	float3x2 rotate2d(float degrees, const float2& center);
	// Rotation about the origin.
	inline float3x2 rotate2d(float degrees) { return rotate2d(degrees, vector2(0, 0)); }
	float3x2 scale2d(float sx, float sy, const float2& center);
	// Scaling by the factors in `scale`, about the origin or about `center`.
	inline float3x2 scale2d(const float2& scale) { return scale2d(scale.x, scale.y, vector2(0, 0)); }
	inline float3x2 scale2d(const float2& scale, const float2& center) { return scale2d(scale.x, scale.y, center); }
	float3x2 skew2d(float degreesX, float degreesY, const float2& center);
	// Skew about the origin.
	inline float3x2 skew2d(float degreesX, float degreesY) { return skew2d(degreesX, degreesY, vector2(0, 0)); }
	float3x2 translate2d(float dx, float dy);
	// Translation by a vector. The result must be returned: the previous body dropped the
	// `return`, and flowing off the end of a non-void function is undefined behavior.
	inline float3x2 translate2d(const float2& vector) { return translate2d(vector.x, vector.y); }
	float3x2 mul(const float3x2& m1, const float3x2& m2);
	float2 mul(const float3x2& m1, const float2& p);
	// Operator forms of mul().
	inline float3x2 operator * (const float3x2& a, const float3x2& b) { return mul(a, b); }
	inline float2 operator * (const float3x2& a, const float2& b) { return mul(a, b); }
	float3x2 invert(const float3x2& m);



	// --- makeXxx/fromXxx helpers: convert between these value structs and the Direct2D/Direct3D/DirectWrite structs

	// Direct2D 3x2 matrix from a float3x2 (member-for-member: m11->_11 ... m32->_32).
	inline D2D1_MATRIX_3X2_F makeD2D(const float3x2& m)
	{
		D2D1_MATRIX_3X2_F out;
		out._11 = m.m11;
		out._12 = m.m12;
		out._21 = m.m21;
		out._22 = m.m22;
		out._31 = m.m31;
		out._32 = m.m32;
		return out;
	}
	// Direct2D point from a float2.
	inline D2D1_POINT_2F makeD2D(const float2& p)
	{
		D2D1_POINT_2F out;
		out.x = p.x;
		out.y = p.y;
		return out;
	}
	// Direct2D point from a WinRT Point.
	inline D2D1_POINT_2F makeD2D(const Windows::Foundation::Point& p)
	{
		D2D1_POINT_2F out;
		out.x = p.X;
		out.y = p.Y;
		return out;
	}
	// Direct2D edge-based rectangle from a WinRT origin+size Rect.
	inline D2D1_RECT_F makeD2D(const Windows::Foundation::Rect& rect)
	{
		D2D1_RECT_F out;
		out.left = rect.X;
		out.top = rect.Y;
		out.right = rect.X + rect.Width;
		out.bottom = rect.Y + rect.Height;
		return out;
	}
	// Direct2D float size from a WinRT Size.
	inline D2D1_SIZE_F makeD2D(const Windows::Foundation::Size& size)
	{
		D2D1_SIZE_F out;
		out.width = size.Width;
		out.height = size.Height;
		return out;
	}
	// Direct2D color from a WinRT Color: each 0..255 channel is scaled to 0..1.
	inline D2D1_COLOR_F makeD2D(const Windows::UI::Color& c)
	{
		D2D1_COLOR_F out;
		out.r = c.R / 255.0f;
		out.g = c.G / 255.0f;
		out.b = c.B / 255.0f;
		out.a = c.A / 255.0f;
		return out;
	}
	// Writes a WinRT Color into values[0..3] as normalized R, G, B, A floats (0..1).
	inline void makeD3D(const Windows::UI::Color& c, float values[4])
	{
		const unsigned char channels[4] = { c.R, c.G, c.B, c.A };
		for (int i = 0; i < 4; i++)
			values[i] = channels[i] / 255.0f;
	}
	// Direct2D 4x4 matrix from a float4x4, element for element: m<r><c> (zero-based)
	// goes to _<r+1><c+1> (one-based).
	inline D2D1_MATRIX_4X4_F makeD2D(const float4x4& m)
	{
		D2D1_MATRIX_4X4_F out;
		out._11 = m.m00; out._12 = m.m01; out._13 = m.m02; out._14 = m.m03;
		out._21 = m.m10; out._22 = m.m11; out._23 = m.m12; out._24 = m.m13;
		out._31 = m.m20; out._32 = m.m21; out._33 = m.m22; out._34 = m.m23;
		out._41 = m.m30; out._42 = m.m31; out._43 = m.m32; out._44 = m.m33;
		return out;
	}

	// Builds a USize from a width and a height.
	inline USize makeUSize(uint32 w, uint32 h)
	{
		USize size;
		size.width = w;
		size.height = h;
		return size;
	}
	// Direct2D unsigned size from a USize.
	inline D2D_SIZE_U makeD2D(USize u)
	{
		D2D_SIZE_U size;
		size.width = u.width;
		size.height = u.height;
		return size;
	}

	// DirectWrite text range from a Range (Start -> startPosition, Length -> length).
	inline DWRITE_TEXT_RANGE makeDWrite(const Range& value)
	{
		DWRITE_TEXT_RANGE range;
		range.startPosition = value.Start;
		range.length = value.Length;
		return range;
	}
	// Range from a DirectWrite text range (the reverse field mapping of makeDWrite).
	inline Range fromDWrite(const DWRITE_TEXT_RANGE& value)
	{
		Range range;
		range.Start = value.startPosition;
		range.Length = value.length;
		return range;
	}

	// Direct3D 11 viewport from a box3d: (x, y) is the top-left corner, width/height the
	// extent, and the depth range is [z, z + depth].
	inline D3D11_VIEWPORT makeD3D(const box3d& box)
	{
		D3D11_VIEWPORT viewport;
		viewport.TopLeftX = box.x;
		viewport.TopLeftY = box.y;
		viewport.Width = box.width;
		viewport.Height = box.height;
		viewport.MinDepth = box.z;
		viewport.MaxDepth = box.z + box.depth;
		return viewport;
	}
	// box3d from a Direct3D 11 viewport: z = MinDepth and depth = MaxDepth - MinDepth.
	inline box3d fromD3D(const D3D11_VIEWPORT& box)
	{
		box3d result;
		result.x = box.TopLeftX;
		result.y = box.TopLeftY;
		result.z = box.MinDepth;
		result.width = box.Width;
		result.height = box.Height;
		result.depth = box.MaxDepth - box.MinDepth;
		return result;
	}
}